{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install tensorflow-gpu==1.14"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/husein/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/optimize.py:187: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/research/neural_stack.py:52: The name tf.nn.rnn_cell.RNNCell is deprecated. Please use tf.compat.v1.nn.rnn_cell.RNNCell instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/trainer_lib.py:111: The name tf.OptimizerOptions is deprecated. Please use tf.compat.v1.OptimizerOptions instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/trainer_lib.py:111: The name tf.OptimizerOptions is deprecated. Please use tf.compat.v1.OptimizerOptions instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensor2tensor.data_generators import problem_hparams\n",
    "from tensor2tensor.models import evolved_transformer\n",
    "from tensor2tensor.models import transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('train-test-bpe.json') as fopen:\n",
    "    dataset = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = dataset['train_X']\n",
    "train_Y = dataset['train_Y']\n",
    "test_X = dataset['test_X']\n",
    "test_Y = dataset['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = transformer.transformer_tiny()\n",
    "english_size = 32000\n",
    "vietnam_size = 32000\n",
    "p_hparams = problem_hparams.test_problem_hparams(english_size, vietnam_size, params)\n",
    "params.problem_hparams = p_hparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'infer'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.estimator.ModeKeys.PREDICT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_second_dim(x, desired_size):\n",
    "    padding = tf.tile([[[0.0]]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1], tf.shape(x)[2]], 0))\n",
    "    return tf.concat([x, padding], 1)\n",
    "\n",
    "class Translator:\n",
    "    def __init__(self, learning_rate = 1e-3, beam_size = 5, training = True):\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype=tf.int32)\n",
    "        self.training = tf.placeholder(tf.bool, None)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        seq_size = tf.shape(self.X)[1]\n",
    "        \n",
    "        maxlen_decode = 2 * tf.reduce_max(self.X_seq_len)\n",
    "        \n",
    "        x = tf.expand_dims(tf.expand_dims(self.X, -1), -1)\n",
    "        y = tf.expand_dims(tf.expand_dims(self.Y, -1), -1)\n",
    "        \n",
    "        print(x, y)\n",
    "        \n",
    "        features = {\n",
    "            \"inputs\": x,\n",
    "            \"targets\": y,\n",
    "            \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "        }\n",
    "        \n",
    "        if training:\n",
    "            mode = 'train'\n",
    "        else:\n",
    "            mode = 'infer'\n",
    "        \n",
    "        model = evolved_transformer.EvolvedTransformer(params,mode,p_hparams)\n",
    "        \n",
    "        logits, _ = model(features)\n",
    "        logits = tf.squeeze(tf.squeeze(logits, 2), 2)\n",
    "        \n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        \n",
    "        y_t = tf.argmax(logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "        \n",
    "        with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n",
    "        \n",
    "            self.fast_result = model._greedy_infer(features, maxlen_decode)[\"outputs\"]\n",
    "            self.beam_result = model._beam_decode_slow(\n",
    "                features, maxlen_decode, beam_size=beam_size, top_beams=1, alpha=1.0)[\"outputs\"]\n",
    "            \n",
    "            print(self.fast_result, self.beam_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"ExpandDims_1:0\", shape=(?, ?, 1, 1), dtype=int32) Tensor(\"ExpandDims_3:0\", shape=(?, ?, 1, 1), dtype=int32)\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:2262: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:2262: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/layers/common_layers.py:100: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/layers/common_layers.py:100: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:245: The name tf.summary.text is deprecated. Please use tf.compat.v1.summary.text instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:245: The name tf.summary.text is deprecated. Please use tf.compat.v1.summary.text instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:num_trainable_top_decoder_layers is negative so training all weights.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:num_trainable_top_decoder_layers is negative so training all weights.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/autograph/converters/directives.py:117: The name tf.get_variable_scope is deprecated. Please use tf.compat.v1.get_variable_scope instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/autograph/converters/directives.py:117: The name tf.get_variable_scope is deprecated. Please use tf.compat.v1.get_variable_scope instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/autograph/converters/directives.py:117: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/autograph/converters/directives.py:117: The name tf.summary.scalar is deprecated. Please use tf.compat.v1.summary.scalar instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py:255: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py:255: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:1373: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:1373: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:417: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py:417: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32000_128.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32000_128.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32000_128.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32000_128.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/layers/common_layers.py:416: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/layers/common_layers.py:416: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/transformer.py:96: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/transformer.py:96: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:170: The name tf.layers.Dense is deprecated. Please use tf.compat.v1.layers.Dense instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:170: The name tf.layers.Dense is deprecated. Please use tf.compat.v1.layers.Dense instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:195: The name tf.layers.Conv1D is deprecated. Please use tf.compat.v1.layers.Conv1D instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:195: The name tf.layers.Conv1D is deprecated. Please use tf.compat.v1.layers.Conv1D instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:215: The name tf.layers.SeparableConv1D is deprecated. Please use tf.compat.v1.layers.SeparableConv1D instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:215: The name tf.layers.SeparableConv1D is deprecated. Please use tf.compat.v1.layers.SeparableConv1D instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/layers/common_attention.py:1324: The name tf.summary.image is deprecated. Please use tf.compat.v1.summary.image instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/layers/common_attention.py:1324: The name tf.summary.image is deprecated. Please use tf.compat.v1.summary.image instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:614: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/evolved_transformer.py:614: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32000_128.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32000_128.top\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py:1354: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py:1354: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/transformer.py:1226: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor/models/transformer.py:1226: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32000_128.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32000_128.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32000_128.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32000_128.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32000_128.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32000_128.top\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"evolved_transformer_1/while/Exit_3:0\", shape=(?, ?), dtype=int64) Tensor(\"strided_slice_12:0\", shape=(?, ?), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[10.888551, None]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_x, _ = pad_sentence_batch(train_X[: 5], 0)\n",
    "batch_y, _ = pad_sentence_batch(train_Y[: 5], 0)\n",
    "\n",
    "sess.run([model.cost, model.optimizer], feed_dict = {model.X: batch_x, model.Y: batch_y})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 96\n",
    "epoch = 20\n",
    "PAD = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:55<00:00,  5.90it/s, accuracy=0.278, cost=4.61]  \n",
      "minibatch loop: 100%|██████████| 30/30 [00:03<00:00,  9.21it/s, accuracy=0.278, cost=4.62]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:27,  6.68it/s, accuracy=0.347, cost=3.96]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 4.941576, training avg acc 0.238279\n",
      "epoch 1, testing avg loss 4.067009, testing avg acc 0.331084\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:44<00:00,  6.18it/s, accuracy=0.347, cost=3.92]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.33it/s, accuracy=0.354, cost=3.96]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:26,  6.73it/s, accuracy=0.438, cost=3.22]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 3.502500, training avg acc 0.398134\n",
      "epoch 2, testing avg loss 3.405555, testing avg acc 0.413942\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.17it/s, accuracy=0.389, cost=3.53]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.34it/s, accuracy=0.371, cost=3.69]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:28,  6.65it/s, accuracy=0.483, cost=2.87]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 2.972033, training avg acc 0.460962\n",
      "epoch 3, testing avg loss 3.154716, testing avg acc 0.446736\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.406, cost=3.21]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.32it/s, accuracy=0.393, cost=3.56]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:25,  6.75it/s, accuracy=0.486, cost=2.7]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 2.704851, training avg acc 0.492125\n",
      "epoch 4, testing avg loss 3.016861, testing avg acc 0.464762\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.408, cost=3.06]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.34it/s, accuracy=0.408, cost=3.47]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:26,  6.73it/s, accuracy=0.508, cost=2.56]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 2.542467, training avg acc 0.510391\n",
      "epoch 5, testing avg loss 2.940933, testing avg acc 0.474570\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.431, cost=2.89]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.32it/s, accuracy=0.429, cost=3.35]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:26,  6.73it/s, accuracy=0.535, cost=2.47]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 2.431287, training avg acc 0.523740\n",
      "epoch 6, testing avg loss 2.895403, testing avg acc 0.479958\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.446, cost=2.73]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.34it/s, accuracy=0.425, cost=3.35]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:26,  6.74it/s, accuracy=0.536, cost=2.38]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 2.348464, training avg acc 0.533226\n",
      "epoch 7, testing avg loss 2.861190, testing avg acc 0.485580\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.15it/s, accuracy=0.466, cost=2.63]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.37it/s, accuracy=0.432, cost=3.31]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:28,  6.67it/s, accuracy=0.542, cost=2.32]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 2.283155, training avg acc 0.541155\n",
      "epoch 8, testing avg loss 2.839225, testing avg acc 0.486344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.15it/s, accuracy=0.468, cost=2.6] \n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.43it/s, accuracy=0.433, cost=3.28]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:27,  6.69it/s, accuracy=0.544, cost=2.28]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 2.230554, training avg acc 0.547377\n",
      "epoch 9, testing avg loss 2.830232, testing avg acc 0.488518\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.468, cost=2.5] \n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.40it/s, accuracy=0.426, cost=3.28]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:28,  6.65it/s, accuracy=0.552, cost=2.23]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 2.185407, training avg acc 0.552695\n",
      "epoch 10, testing avg loss 2.821573, testing avg acc 0.488192\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.48, cost=2.4]  \n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.32it/s, accuracy=0.43, cost=3.28] \n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:24,  6.78it/s, accuracy=0.561, cost=2.2]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 2.148526, training avg acc 0.557104\n",
      "epoch 11, testing avg loss 2.811038, testing avg acc 0.491911\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.475, cost=2.41]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.36it/s, accuracy=0.44, cost=3.31] \n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:26,  6.71it/s, accuracy=0.557, cost=2.17]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 2.115184, training avg acc 0.561326\n",
      "epoch 12, testing avg loss 2.808708, testing avg acc 0.492841\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.503, cost=2.32]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.32it/s, accuracy=0.439, cost=3.27]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:29,  6.61it/s, accuracy=0.553, cost=2.14]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 2.086918, training avg acc 0.564878\n",
      "epoch 13, testing avg loss 2.800548, testing avg acc 0.494502\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.515, cost=2.25]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.33it/s, accuracy=0.45, cost=3.25] \n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:23,  6.81it/s, accuracy=0.566, cost=2.11]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 2.060728, training avg acc 0.568022\n",
      "epoch 14, testing avg loss 2.789417, testing avg acc 0.495979\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.17it/s, accuracy=0.508, cost=2.18]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.39it/s, accuracy=0.432, cost=3.26]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:26,  6.73it/s, accuracy=0.569, cost=2.12]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 2.037601, training avg acc 0.570928\n",
      "epoch 15, testing avg loss 2.787962, testing avg acc 0.496062\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.493, cost=2.23]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.32it/s, accuracy=0.434, cost=3.27]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:27,  6.69it/s, accuracy=0.57, cost=2.09]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 2.016497, training avg acc 0.573804\n",
      "epoch 16, testing avg loss 2.790819, testing avg acc 0.496480\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.17it/s, accuracy=0.517, cost=2.2] \n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.40it/s, accuracy=0.445, cost=3.28]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:30,  6.60it/s, accuracy=0.573, cost=2.06]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 1.998177, training avg acc 0.575945\n",
      "epoch 17, testing avg loss 2.784592, testing avg acc 0.497487\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.506, cost=2.21]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.38it/s, accuracy=0.432, cost=3.27]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:23,  6.83it/s, accuracy=0.567, cost=2.04]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 1.980102, training avg acc 0.578285\n",
      "epoch 18, testing avg loss 2.787682, testing avg acc 0.497013\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.16it/s, accuracy=0.525, cost=2.1] \n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.34it/s, accuracy=0.447, cost=3.23]\n",
      "minibatch loop:   0%|          | 1/1389 [00:00<03:22,  6.87it/s, accuracy=0.582, cost=2]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 1.964271, training avg acc 0.580201\n",
      "epoch 19, testing avg loss 2.790944, testing avg acc 0.495617\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1389/1389 [03:45<00:00,  6.15it/s, accuracy=0.511, cost=2.16]\n",
      "minibatch loop: 100%|██████████| 30/30 [00:02<00:00, 11.33it/s, accuracy=0.436, cost=3.23]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 1.950056, training avg acc 0.582312\n",
      "epoch 20, testing avg loss 2.785332, testing avg acc 0.497402\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    # Training pass: model.optimizer is evaluated together with the metrics on every minibatch.\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    # Per-batch metrics gathered during this epoch; their plain means are printed at the end.\n",
    "    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
    "    for start in pbar:\n",
    "        stop = min(start + batch_size, len(train_X))\n",
    "        # Second return value of pad_sentence_batch is not needed here.\n",
    "        batch_x, _ = pad_sentence_batch(train_X[start : stop], PAD)\n",
    "        batch_y, _ = pad_sentence_batch(train_Y[start : stop], PAD)\n",
    "        accuracy, loss, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {model.X: batch_x, model.Y: batch_y})\n",
    "        train_loss.append(loss)\n",
    "        train_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "\n",
    "    # Evaluation pass: same metrics on the test split, optimizer op left out of the fetches.\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
    "    for start in pbar:\n",
    "        stop = min(start + batch_size, len(test_X))\n",
    "        batch_x, _ = pad_sentence_batch(test_X[start : stop], PAD)\n",
    "        batch_y, _ = pad_sentence_batch(test_Y[start : stop], PAD)\n",
    "        accuracy, loss = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {model.X: batch_x, model.Y: batch_y})\n",
    "        test_loss.append(loss)\n",
    "        test_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "\n",
    "    # Report 1-based epoch numbers with the unweighted mean of the per-batch values.\n",
    "    epoch_no = e + 1\n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'\n",
    "          % (epoch_no, np.mean(train_loss), np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'\n",
    "          % (epoch_no, np.mean(test_loss), np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the BPE tokenizer from vietnam.model; used afterwards to decode token ids to text.\n",
    "import youtokentome as yttm\n",
    "\n",
    "bpe = yttm.BPE(model = 'vietnam.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feed only the first test sentence (a batch of one) and fetch model.fast_result.\n",
    "first_source = test_X[:1]\n",
    "greedy = sess.run(model.fast_result, feed_dict = {model.X: first_source})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([1377,\n",
       "  577,\n",
       "  278,\n",
       "  281,\n",
       "  323,\n",
       "  627,\n",
       "  1479,\n",
       "  321,\n",
       "  316,\n",
       "  1196,\n",
       "  353,\n",
       "  2097,\n",
       "  1485,\n",
       "  818,\n",
       "  631,\n",
       "  288,\n",
       "  317,\n",
       "  784,\n",
       "  807,\n",
       "  575,\n",
       "  942,\n",
       "  365,\n",
       "  620,\n",
       "  256,\n",
       "  353,\n",
       "  403,\n",
       "  346,\n",
       "  365,\n",
       "  440,\n",
       "  288,\n",
       "  2097,\n",
       "  1485,\n",
       "  896,\n",
       "  1687,\n",
       "  873,\n",
       "  656,\n",
       "  1230,\n",
       "  545,\n",
       "  327,\n",
       "  12091,\n",
       "  2041,\n",
       "  859,\n",
       "  527,\n",
       "  513,\n",
       "  291,\n",
       "  279,\n",
       "  670,\n",
       "  858,\n",
       "  1517,\n",
       "  827,\n",
       "  2944,\n",
       "  299,\n",
       "  794,\n",
       "  334,\n",
       "  975,\n",
       "  800,\n",
       "  1195,\n",
       "  723,\n",
       "  291,\n",
       "  670,\n",
       "  858,\n",
       "  256,\n",
       "  334,\n",
       "  986,\n",
       "  270,\n",
       "  969,\n",
       "  321,\n",
       "  1208,\n",
       "  449,\n",
       "  582,\n",
       "  1522,\n",
       "  437,\n",
       "  429,\n",
       "  2207,\n",
       "  723,\n",
       "  11296,\n",
       "  19468,\n",
       "  437,\n",
       "  5521,\n",
       "  1762,\n",
       "  415,\n",
       "  316,\n",
       "  449,\n",
       "  596,\n",
       "  256,\n",
       "  288,\n",
       "  2097,\n",
       "  1485,\n",
       "  818,\n",
       "  631,\n",
       "  327,\n",
       "  12091,\n",
       "  859,\n",
       "  527,\n",
       "  888,\n",
       "  670,\n",
       "  858,\n",
       "  545,\n",
       "  270,\n",
       "  302,\n",
       "  710,\n",
       "  512,\n",
       "  1701,\n",
       "  453,\n",
       "  333,\n",
       "  333,\n",
       "  670,\n",
       "  858,\n",
       "  545,\n",
       "  512,\n",
       "  513,\n",
       "  348,\n",
       "  2385,\n",
       "  3354,\n",
       "  270,\n",
       "  401,\n",
       "  381,\n",
       "  576,\n",
       "  334,\n",
       "  297,\n",
       "  318,\n",
       "  1056,\n",
       "  699,\n",
       "  392],\n",
       " [1377,\n",
       "  365,\n",
       "  440,\n",
       "  351,\n",
       "  381,\n",
       "  321,\n",
       "  316,\n",
       "  1196,\n",
       "  353,\n",
       "  288,\n",
       "  317,\n",
       "  784,\n",
       "  807,\n",
       "  620,\n",
       "  256,\n",
       "  403,\n",
       "  281,\n",
       "  885,\n",
       "  896,\n",
       "  873,\n",
       "  656,\n",
       "  1230,\n",
       "  888,\n",
       "  513,\n",
       "  827,\n",
       "  291,\n",
       "  670,\n",
       "  986,\n",
       "  858,\n",
       "  582,\n",
       "  279,\n",
       "  739,\n",
       "  2207,\n",
       "  723,\n",
       "  348,\n",
       "  1080,\n",
       "  1070,\n",
       "  415,\n",
       "  449,\n",
       "  596,\n",
       "  818,\n",
       "  631,\n",
       "  1137,\n",
       "  337,\n",
       "  429,\n",
       "  1195,\n",
       "  270,\n",
       "  302,\n",
       "  710,\n",
       "  512,\n",
       "  453,\n",
       "  258,\n",
       "  336,\n",
       "  401,\n",
       "  2385,\n",
       "  3354,\n",
       "  576,\n",
       "  334,\n",
       "  545,\n",
       "  699,\n",
       "  392,\n",
       "  1036,\n",
       "  268,\n",
       "  969,\n",
       "  1635,\n",
       "  988,\n",
       "  406,\n",
       "  1111,\n",
       "  326,\n",
       "  1294,\n",
       "  1618,\n",
       "  657,\n",
       "  1056,\n",
       "  333,\n",
       "  1028,\n",
       "  541,\n",
       "  327,\n",
       "  1827,\n",
       "  527,\n",
       "  346])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference ids for the first test pair.\n",
    "actual = test_Y[0]\n",
    "\n",
    "# dict.fromkeys keeps only the first occurrence of each id, in order of appearance.\n",
    "# NOTE(review): this also drops ids that legitimately repeat within a sentence - confirm this is intended.\n",
    "predicted_ids = greedy[0].tolist()\n",
    "predicted = list(dict.fromkeys(predicted_ids))\n",
    "\n",
    "actual, predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Làm sao tôi có thể trình bày trong <NUM> phút về sợi dây liên kết những người phụ nữ qua ba thế hệ , về việc làm thế nào những sợi dây mạnh mẽ đáng kinh ngạc ấy đã níu chặt lấy cuộc sống của một cô bé bốn tuổi co quắp với đứa em gái nhỏ của cô bé , với mẹ và bà trong suốt năm ngày đêm trên con thuyền nhỏ lênh đênh trên Biển Đông hơn <NUM> năm trước , những sợi dây liên kết đã níu lấy cuộc đời cô bé ấy và không bao giờ rời đi - - cô bé ấy giờ sống ở San Francisco và đang nói chuyện với các bạn hôm nay ?']"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference translation: decode the target ids back to text.\n",
    "bpe.decode(actual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Làm thế nào để nói trong <NUM> phút về những người phụ nữ hệ , việc có sức mạnh đáng kinh ngạc đời sống tuổi của cô mẹ bé ngày một chiếc thuyền nhỏ ở Trung Quốc hơn năm trước liên kết giữ cho con gái và không bao giờ đi . \" đang San Francisco chuyện với ấy nay ? bây là bà già tháng lại giao này Cô nam nữa hôm - đàn ông đã nắm cuộc làm']"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model output: decode the de-duplicated predicted ids back to text.\n",
    "bpe.decode(predicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
